1
超越元素級運算:轉向分塊矩陣運算
AI023Lesson 9
00:00

在先前的課程中,我們專注於 元素級運算 (例如對矩陣執行基本的 ReLU)。這些運算屬於 記憶體限制型 因為顯示卡花費更多時間將資料從高頻寬記憶體(HBM)移動到暫存器,而非進行數學運算。

1. 為何 GEMM 至關重要

一般矩陣乘法(GEMM)的計算複雜度為 $O(N^3)$,但僅需 $O(N^2)$ 的記憶體存取。這讓我們能以巨大的算術吞吐量隱藏記憶體延遲,使其成為大型語言模型(LLMs)的「心臟」。

2. 二維記憶體表示

實際的記憶體是 1 維的。要表示一個二維張量,我們使用 步幅(Stride)。一個常見的生產環境陷阱是 假設張量是連續的。若你在指標運算中混淆了列與行的步幅,將會讀取到「幽靈資料」或引發記憶體違規。

3. 分塊泛化

Triton 透過從 單一指標 轉向 指標塊。藉由使用二維分塊(例如 $16 \times 16$),我們可利用 資料重用 高速 SRAM 中的特性,使資料保持『熱態』,以便在寫回全域記憶體前,進行融合運算,如偏置加法或激活函數。

1D 線性佈局2D 分塊佈局
main.py
TERMINALbash — 80x24
> Ready. Click "Run" to execute.
>